13c8e4
@@ -28,6 +28,7 @@
import java.util.ListIterator;
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.commons.collections.bag.HashBag;
 import org.apache.commons.collections.collection.PredicatedCollection;
 import org.apache.commons.collections.collection.SynchronizedCollection;
 import org.apache.commons.collections.collection.TransformedCollection;
@@ -245,6 +246,12 @@
public class CollectionUtils {
      * Returns a new {@link Collection} containing <i>a</i> minus a subset of
      * <i>b</i>.  Only the elements of <i>b</i> that satisfy the predicate
      * condition, <i>p</i> are subtracted from <i>a</i>.
+     * 
+     * <p>The cardinality of each element <i>e</i> in the returned {@link Collection}
+     * that satisfies the predicate condition will be the cardinality of <i>e</i> in <i>a</i>
+     * minus the cardinality of <i>e</i> in <i>b</i>, or zero, whichever is greater.</p>
+     * <p>The cardinality of each element <i>e</i> in the returned {@link Collection} that does <b>not</b>
+     * satisfy the predicate condition will be equal to the cardinality of <i>e</i> in <i>a</i>.</p>
      *
      * @param a  the collection to subtract from, must not be null
      * @param b  the collection to subtract, must not be null
@@ -256,12 +263,19 @@
public class CollectionUtils {
      * @since 4.0
      * @see Collection#removeAll
      */
-    public static <O> Collection<O> subtract(final Iterable<? extends O> a, final Iterable<? extends O> b, final Predicate<O> p) {
-        ArrayList<O> list = new ArrayList<O>();
-        addAll(list, a);
+    public static <O> Collection<O> subtract(final Iterable<? extends O> a,
+                                             final Iterable<? extends O> b,
+                                             final Predicate<O> p) {
+        final ArrayList<O> list = new ArrayList<O>();
+        final HashBag<O> bag = new HashBag<O>();
         for (O element : b) {
             if (p.evaluate(element)) {
-                list.remove(element);
+                bag.add(element);
+            }
+        }
+        for (O element : a) {
+            if (!bag.remove(element, 1)) {
+                list.add(element);
             }
         }
         return list;
